{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Pruning Tutorial\n**Author**: [Michela Paganini](https://github.com/mickypaganini)\n\nState-of-the-art deep learning techniques rely on over-parametrized models \nthat are hard to deploy. On the contrary, biological neural networks are \nknown to use efficient sparse connectivity. Identifying optimal  \ntechniques to compress models by reducing the number of parameters in them is \nimportant in order to reduce memory, battery, and hardware consumption without \nsacrificing accuracy. This in turn allows you to deploy lightweight models on device, and guarantee \nprivacy with private on-device computation. On the research front, pruning is \nused to investigate the differences in learning dynamics between \nover-parametrized and under-parametrized networks, to study the role of lucky \nsparse subnetworks and initializations\n(\"[lottery tickets](https://arxiv.org/abs/1803.03635)\") as a destructive \nneural architecture search technique, and more.\n\nIn this tutorial, you will learn how to use ``torch.nn.utils.prune`` to \nsparsify your neural networks, and how to extend it to implement your \nown custom pruning technique.\n\n## Requirements\n``\"torch>=1.4.0a0+8e8a5e0\"``\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nfrom torch import nn\nimport torch.nn.utils.prune as prune\nimport torch.nn.functional as F"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Create a model\n\nIn this tutorial, we use the [LeNet](http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf) architecture from \nLeCun et al., 1998.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\nclass LeNet(nn.Module):\n    def __init__(self):\n        super(LeNet, self).__init__()\n        # 1 input image channel, 6 output channels, 3x3 square conv kernel\n        self.conv1 = nn.Conv2d(1, 6, 3)\n        self.conv2 = nn.Conv2d(6, 16, 3)\n        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension\n        self.fc2 = nn.Linear(120, 84)\n        self.fc3 = nn.Linear(84, 10)\n\n    def forward(self, x):\n        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))\n        x = F.max_pool2d(F.relu(self.conv2(x)), 2)\n        x = x.view(-1, int(x.nelement() / x.shape[0]))\n        x = F.relu(self.fc1(x))\n        x = F.relu(self.fc2(x))\n        x = self.fc3(x)\n        return x\n\nmodel = LeNet().to(device=device)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Inspect a Module\n\nLet's inspect the (unpruned) ``conv1`` layer in our LeNet model. It will contain two \nparameters ``weight`` and ``bias``, and no buffers, for now.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "module = model.conv1\nprint(list(module.named_parameters()))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(list(module.named_buffers()))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Pruning a Module\n\nTo prune a module (in this example, the ``conv1`` layer of our LeNet \narchitecture), first select a pruning technique among those available in \n``torch.nn.utils.prune`` (or\n[implement](#extending-torch-nn-utils-pruning-with-custom-pruning-functions)\nyour own by subclassing \n``BasePruningMethod``). Then, specify the module and the name of the parameter to \nprune within that module. Finally, using the adequate keyword arguments \nrequired by the selected pruning technique, specify the pruning parameters.\n\nIn this example, we will prune at random 30% of the connections in \nthe parameter named ``weight`` in the ``conv1`` layer.\nThe module is passed as the first argument to the function; ``name`` \nidentifies the parameter within that module using its string identifier; and \n``amount`` indicates either the percentage of connections to prune (if it \nis a float between 0. and 1.), or the absolute number of connections to \nprune (if it is a non-negative integer).\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "prune.random_unstructured(module, name=\"weight\", amount=0.3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Pruning acts by removing ``weight`` from the parameters and replacing it with \na new parameter called ``weight_orig`` (i.e. appending ``\"_orig\"`` to the \ninitial parameter ``name``). ``weight_orig`` stores the unpruned version of \nthe tensor. The ``bias`` was not pruned, so it will remain intact.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(list(module.named_parameters()))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The pruning mask generated by the pruning technique selected above is saved \nas a module buffer named ``weight_mask`` (i.e. appending ``\"_mask\"`` to the \ninitial parameter ``name``).\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(list(module.named_buffers()))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For the forward pass to work without modification, the ``weight`` attribute \nneeds to exist. The pruning techniques implemented in \n``torch.nn.utils.prune`` compute the pruned version of the weight (by \ncombining the mask with the original parameter) and store them in the \nattribute ``weight``. Note, this is no longer a parameter of the ``module``,\nit is now simply an attribute.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(module.weight)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Finally, pruning is applied prior to each forward pass using PyTorch's\n``forward_pre_hooks``. Specifically, when the ``module`` is pruned, as we \nhave done here, it will acquire a ``forward_pre_hook`` for each parameter \nassociated with it that gets pruned. In this case, since we have so far \nonly pruned the original parameter named ``weight``, only one hook will be\npresent.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(module._forward_pre_hooks)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For completeness, we can now prune the ``bias`` too, to see how the \nparameters, buffers, hooks, and attributes of the ``module`` change.\nJust for the sake of trying out another pruning technique, here we prune the \n3 smallest entries in the bias by L1 norm, as implemented in the \n``l1_unstructured`` pruning function.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "prune.l1_unstructured(module, name=\"bias\", amount=3)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We now expect the named parameters to include both ``weight_orig`` (from \nbefore) and ``bias_orig``. The buffers will include ``weight_mask`` and \n``bias_mask``. The pruned versions of the two tensors will exist as \nmodule attributes, and the module will now have two ``forward_pre_hooks``.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(list(module.named_parameters()))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(list(module.named_buffers()))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(module.bias)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(module._forward_pre_hooks)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Iterative Pruning\n\nThe same parameter in a module can be pruned multiple times, with the \neffect of the various pruning calls being equal to the combination of the\nvarious masks applied in series.\nThe combination of a new mask with the old mask is handled by the \n``PruningContainer``'s ``compute_mask`` method.\n\nSay, for example, that we now want to further prune ``module.weight``, this\ntime using structured pruning along the 0th axis of the tensor (the 0th axis \ncorresponds to the output channels of the convolutional layer and has \ndimensionality 6 for ``conv1``), based on the channels' L2 norm. This can be \nachieved using the ``ln_structured`` function, with ``n=2`` and ``dim=0``.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "prune.ln_structured(module, name=\"weight\", amount=0.5, n=2, dim=0)\n\n# As we can verify, this will zero out all the connections corresponding to \n# 50% (3 out of 6) of the channels, while preserving the action of the \n# previous mask.\nprint(module.weight)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The corresponding hook will now be of type \n``torch.nn.utils.prune.PruningContainer``, and will store the history of \npruning applied to the ``weight`` parameter.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "for hook in module._forward_pre_hooks.values():\n    if hook._tensor_name == \"weight\":  # select out the correct hook\n        break\n\nprint(list(hook))  # pruning history in the container"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Serializing a pruned model\nAll relevant tensors, including the mask buffers and the original parameters\nused to compute the pruned tensors are stored in the model's ``state_dict`` \nand can therefore be easily serialized and saved, if needed.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(model.state_dict().keys())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Remove pruning re-parametrization\n\nTo make the pruning permanent, remove the re-parametrization in terms\nof ``weight_orig`` and ``weight_mask``, and remove the ``forward_pre_hook``,\nwe can use the ``remove`` functionality from ``torch.nn.utils.prune``.\nNote that this doesn't undo the pruning, as if it never happened. It simply \nmakes it permanent, instead, by reassigning the parameter ``weight`` to the \nmodel parameters, in its pruned version.\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Prior to removing the re-parametrization:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(list(module.named_parameters()))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(list(module.named_buffers()))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(module.weight)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "After removing the re-parametrization:\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "prune.remove(module, 'weight')\nprint(list(module.named_parameters()))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(list(module.named_buffers()))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Pruning multiple parameters in a model \n\nBy specifying the desired pruning technique and parameters, we can easily \nprune multiple tensors in a network, perhaps according to their type, as we \nwill see in this example.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "new_model = LeNet()\nfor name, module in new_model.named_modules():\n    # prune 20% of connections in all 2D-conv layers \n    if isinstance(module, torch.nn.Conv2d):\n        prune.l1_unstructured(module, name='weight', amount=0.2)\n    # prune 40% of connections in all linear layers \n    elif isinstance(module, torch.nn.Linear):\n        prune.l1_unstructured(module, name='weight', amount=0.4)\n\nprint(dict(new_model.named_buffers()).keys())  # to verify that all masks exist"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Global pruning\n\nSo far, we only looked at what is usually referred to as \"local\" pruning,\ni.e. the practice of pruning tensors in a model one by one, by \ncomparing the statistics (weight magnitude, activation, gradient, etc.) of \neach entry exclusively to the other entries in that tensor. However, a \ncommon and perhaps more powerful technique is to prune the model all at \nonce, by removing (for example) the lowest 20% of connections across the \nwhole model, instead of removing the lowest 20% of connections in each \nlayer. This is likely to result in different pruning percentages per layer.\nLet's see how to do that using ``global_unstructured`` from \n``torch.nn.utils.prune``.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "model = LeNet()\n\nparameters_to_prune = (\n    (model.conv1, 'weight'),\n    (model.conv2, 'weight'),\n    (model.fc1, 'weight'),\n    (model.fc2, 'weight'),\n    (model.fc3, 'weight'),\n)\n\nprune.global_unstructured(\n    parameters_to_prune,\n    pruning_method=prune.L1Unstructured,\n    amount=0.2,\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now we can check the sparsity induced in every pruned parameter, which will \nnot be equal to 20% in each layer. However, the global sparsity will be \n(approximately) 20%.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(\n    \"Sparsity in conv1.weight: {:.2f}%\".format(\n        100. * float(torch.sum(model.conv1.weight == 0))\n        / float(model.conv1.weight.nelement())\n    )\n)\nprint(\n    \"Sparsity in conv2.weight: {:.2f}%\".format(\n        100. * float(torch.sum(model.conv2.weight == 0))\n        / float(model.conv2.weight.nelement())\n    )\n)\nprint(\n    \"Sparsity in fc1.weight: {:.2f}%\".format(\n        100. * float(torch.sum(model.fc1.weight == 0))\n        / float(model.fc1.weight.nelement())\n    )\n)\nprint(\n    \"Sparsity in fc2.weight: {:.2f}%\".format(\n        100. * float(torch.sum(model.fc2.weight == 0))\n        / float(model.fc2.weight.nelement())\n    )\n)\nprint(\n    \"Sparsity in fc3.weight: {:.2f}%\".format(\n        100. * float(torch.sum(model.fc3.weight == 0))\n        / float(model.fc3.weight.nelement())\n    )\n)\nprint(\n    \"Global sparsity: {:.2f}%\".format(\n        100. * float(\n            torch.sum(model.conv1.weight == 0)\n            + torch.sum(model.conv2.weight == 0)\n            + torch.sum(model.fc1.weight == 0)\n            + torch.sum(model.fc2.weight == 0)\n            + torch.sum(model.fc3.weight == 0)\n        )\n        / float(\n            model.conv1.weight.nelement()\n            + model.conv2.weight.nelement()\n            + model.fc1.weight.nelement()\n            + model.fc2.weight.nelement()\n            + model.fc3.weight.nelement()\n        )\n    )\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Extending ``torch.nn.utils.prune`` with custom pruning functions\nTo implement your own pruning function, you can extend the\n``nn.utils.prune`` module by subclassing the ``BasePruningMethod``\nbase class, the same way all other pruning methods do. The base class\nimplements the following methods for you: ``__call__``, ``apply_mask``,\n``apply``, ``prune``, and ``remove``. Beyond some special cases, you shouldn't\nhave to reimplement these methods for your new pruning technique.\nYou will, however, have to implement ``__init__`` (the constructor),\nand ``compute_mask`` (the instructions on how to compute the mask\nfor the given tensor according to the logic of your pruning\ntechnique). In addition, you will have to specify which type of\npruning this technique implements (supported options are ``global``,\n``structured``, and ``unstructured``). This is needed to determine\nhow to combine masks in the case in which pruning is applied\niteratively. In other words, when pruning a pre-pruned parameter,\nthe current prunining techique is expected to act on the unpruned\nportion of the parameter. Specifying the ``PRUNING_TYPE`` will\nenable the ``PruningContainer`` (which handles the iterative\napplication of pruning masks) to correctly identify the slice of the\nparameter to prune.\n\nLet's assume, for example, that you want to implement a pruning\ntechnique that prunes every other entry in a tensor (or -- if the\ntensor has previously been pruned -- in the remaining unpruned\nportion of the tensor). This will be of ``PRUNING_TYPE='unstructured'``\nbecause it acts on individual connections in a layer and not on entire\nunits/channels (``'structured'``), or across different parameters\n(``'global'``).\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class FooBarPruningMethod(prune.BasePruningMethod):\n    \"\"\"Prune every other entry in a tensor\n    \"\"\"\n    PRUNING_TYPE = 'unstructured'\n\n    def compute_mask(self, t, default_mask):\n        mask = default_mask.clone()\n        mask.view(-1)[::2] = 0 \n        return mask"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Now, to apply this to a parameter in an ``nn.Module``, you should\nalso provide a simple function that instantiates the method and\napplies it.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def foobar_unstructured(module, name):\n    \"\"\"Prunes tensor corresponding to parameter called `name` in `module`\n    by removing every other entry in the tensors.\n    Modifies module in place (and also return the modified module) \n    by:\n    1) adding a named buffer called `name+'_mask'` corresponding to the \n    binary mask applied to the parameter `name` by the pruning method.\n    The parameter `name` is replaced by its pruned version, while the \n    original (unpruned) parameter is stored in a new parameter named \n    `name+'_orig'`.\n\n    Args:\n        module (nn.Module): module containing the tensor to prune\n        name (string): parameter name within `module` on which pruning\n                will act.\n\n    Returns:\n        module (nn.Module): modified (i.e. pruned) version of the input\n            module\n    \n    Examples:\n        >>> m = nn.Linear(3, 4)\n        >>> foobar_unstructured(m, name='bias')\n    \"\"\"\n    FooBarPruningMethod.apply(module, name)\n    return module"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's try it out!\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "model = LeNet()\nfoobar_unstructured(model.fc3, name='bias')\n\nprint(model.fc3.bias_mask)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}